
clear all

% Shared state. These globals are presumably read by the objective
% mr_mle_2Nn_Nd2 (defined in another file) -- confirm there.
global dat
global dat2
global g
global D
global k

bnewall = [];   % MLE estimate from every outer iteration, one column each
g = 532;        % number of groups
D = 100;        % number of simulation draws per group

% Group-specific random effects: a g-by-D matrix of standard-normal draws.
% randn('state', seed) selects MATLAB's legacy normal generator; replacing it
% with rng(seed) would produce different draws, so it is kept as is.
randn('state', 768795465)
miu = randn(g, D);

N = 0;          % total number of observations over all groups

% Data file for part p of group grp, e.g. part3/data_part3_17.txt.
% fullfile inserts the platform's path separator (the original hard-coded '\',
% which only works on Windows).
partFile = @(p, grp) fullfile(sprintf('part%d', p), ...
    sprintf('data_part%d_%d.txt', p, grp));

nCoreParts = 4;        % parts 1-4: regressors; part 4 also holds the outcome
schoolParts = 5:36;    % parts 5-36: school dummies

for i = 1:g
    %%% Load X and Y %%%
    core = cell(1, nCoreParts);
    for p = 1:nCoreParts
        core{p} = load(partFile(p, i));
    end

    % School dummies (parts 5-36).
    % NOTE(review): these are loaded but NOT included in X below, and k plus
    % the starting values set later must match the number of columns of X.
    % Confirm whether the dummies were meant to enter the model.
    schoolDummies = cell(1, numel(schoolParts));
    for p = 1:numel(schoolParts)
        schoolDummies{p} = load(partFile(schoolParts(p), i));
    end

    d4 = core{4};
    Xx = [core{1} core{2} core{3} d4(:, 1:4)];
    % Regressors plus the square of the first column, scaled by 1/10.
    X = sparse([Xx, Xx(:, 1).^2/10]);

    % The outcome is 0/1 in column 5 of part 4; recode it as -1/+1.
    Y = sparse(2*d4(:, 5) - 1);

    mr = length(Y);   % number of observations in group i
    N = N + mr;

    %%% Load the weight matrix W %%%
    % Each row of the file is a (row, column, value) triplet.
    weight = load(fullfile('weight', sprintf('w_%d.txt', i)));
    ww = sparse(weight(:, 1), weight(:, 2), weight(:, 3), mr, mr);
    ww = ww - diag(diag(ww));           % remove self-links

    % Row-normalise: divide each row by its sum; rows summing to zero stay zero.
    rowSum = full(sum(ww, 2));
    invRowSum = zeros(mr, 1);
    hasLinks = rowSum ~= 0;
    invRowSum(hasLinks) = 1 ./ rowSum(hasLinks);
    % spdiags keeps the product sparse instead of building a dense mr-by-mr matrix.
    Wr = spdiags(invRowSum, 0, mr, mr) * ww;

    dat(i).yt = Y;
    dat(i).xt = X;
    dat(i).wt = Wr;
    dat(i).unobs = miu(i, :);
end

 k = 17;   % number of columns of X
% Starting values taken from the results of a previous model. The parameter
% vector b1 is indexed as:
%   b1(1:k)       coefficients on X
%   b1(k+1:2*k)   coefficients on W*X
%   b1(2*k+1)     coefficient on W times the endogenous outcome
%   b1(2*k+2)     constant
%   b1(2*k+3)     log of the random-effect scale
tb1 = [ 0.7744
        0.0006
        0.0186
       -0.5369
       -0.2288
       -0.2220
        0.0892
       -0.1416
        0.0339
       -0.0355
       -0.0994
        0.0919
        0.0472
        0.0652
        0.0705
       -0.1138
       -0.2267
        0.0691
       -0.0099
       -0.0240
        0.0374
       -0.0898
       -0.0769
        0.0785
       -0.1198
        0.0397
       -0.0892
       -0.0008
       -0.0549
        0.0830
        0.0386
        0.0306
       -0.0233
       -0.0297
        0.6661
       -6.7391 ];

% Append 0 as the starting log scale of the random effect (scale exp(0) = 1).
b1 = [tb1; 0];

% Every later index into b1 assumes this length; fail early on a mismatch.
if numel(b1) ~= 2*k + 3
    error('Starting vector has %d entries; expected 2*k+3 = %d.', ...
        numel(b1), 2*k + 3);
end

%%%%%%%%%%%%%%%%%%% Outer loop: solve the outcome fixed point, then re-estimate by MLE
% flag  : 1 while the outer loop should continue, 0 once a stopping rule fires
% count : outer-iteration counter; also scales fminunc's iteration budget
% SSE_b : squared parameter change recorded at each non-final iteration
flag = 1;
count = 1;
maxOuterIter = 1000;         % stop once count exceeds this
maxFixedPointIter = 10000;   % safety cap on the inner tanh fixed-point iteration
SSE_b = zeros(maxOuterIter, 1);

while flag == 1
    rho   = b1(2*k + 1);       % coefficient on W times the endogenous outcome
    b0    = b1(2*k + 2);       % constant
    sigma = exp(b1(2*k + 3));  % random-effect scale (stored as a log in b1)
    nNotConverged = 0;

    for i = 1:g
        wt = dat(i).wt;
        xt = dat(i).xt;
        unobs = dat(i).unobs;   % equals miu(i,:)
        mr = size(xt, 1);

        % Parts of the index that do not depend on the iterate x0. They are
        % combined in the same left-to-right order as the original single
        % expression, so the floating-point result is unchanged.
        exogIndex = xt*b1(1:k) + wt*xt(:, 1:k)*b1(k+1:2*k);
        constVec = b0*ones(mr, 1);

        endog = zeros(mr, D);
        for j = 1:D
            drawVec = sigma*unobs(j)*ones(mr, 1);
            % Solve x = tanh(exogIndex + rho*W*x + b0 + sigma*u_j) by
            % successive substitution, starting from x = 1/2.
            x0 = ones(mr, 1)/2;
            euclid = 1;
            nIter = 0;
            while euclid > 1e-7 && nIter < maxFixedPointIter
                xx_new = tanh(exogIndex + wt*x0*rho + constVec + drawVec);
                step = xx_new - x0;
                euclid = sum(step.*step);   % squared distance between iterates
                x0 = xx_new;
                nIter = nIter + 1;
            end
            if ~(euclid <= 1e-7)            % also counts a NaN distance
                nNotConverged = nNotConverged + 1;
            end
            endog(:, j) = xx_new;
        end
        dat2(i).Mg = wt*endog;
    end

    if nNotConverged > 0
        warning(['Fixed point did not converge within %d iterations for ' ...
            '%d group/draw pairs (outer iteration %d).'], ...
            maxFixedPointIter, nNotConverged, count);
    end

    %%%% MLE %%%%
    options = optimset('TolX',1e-7,'TolFun',1e-7,'MaxFunEvals',5e+3*count, ...
        'MaxIter',count*100,'GradObj','on','DerivativeCheck','off','Display','off');
    [bnew,fval(count),exitflag(count),output(count),grad,hessian] = ...
        fminunc(@mr_mle_2Nn_Nd2,b1,options);
    bnewall = [bnewall bnew];

    %%%% Convergence check %%%%
    % Compare parameters with the random-effect scale on its natural scale.
    bnewss = bnew;
    bnewss(2*k + 3) = exp(bnewss(2*k + 3));
    b1ss = b1;
    b1ss(2*k + 3) = exp(b1ss(2*k + 3));
    SSE = (bnewss - b1ss)'*(bnewss - b1ss);

    if isnan(SSE) || SSE <= 1e-4 || count > maxOuterIter
        flag = 0;
    else
        SSE_b(count) = SSE;
        % Damped update: the step towards bnew shrinks as 1/ceil(count/2).
        b1 = (bnew - b1)/(ceil(count/2)) + b1;
        count = count + 1   % no semicolon: echoes progress
    end
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Calculate standard errors
% Standard errors from the outer product of the per-observation score
% contributions, evaluated at the final estimate bnew and at the Mg matrices
% stored in dat2 by the last fixed-point pass.
nPar = 2*k + 3;
rho = bnew(2*k + 1);
sigma = exp(bnew(2*k + 3));
graf = zeros(nPar, nPar);

for i = 1:g
    yt = dat(i).yt;
    wt = dat(i).wt;
    xt = dat(i).xt;
    unobs = dat(i).unobs;           % 1-by-D random-effect draws for group i
    Mg = dat2(i).Mg;
    mr = size(xt, 1);

    wxt = wt*xt(:, 1:k);            % identical for every draw
    indp_b1f = zeros(mr, 1);        % sum over draws of P(y = +1)
    gra_gf = zeros(mr, nPar);       % sum over draws of the index gradient

    for j = 1:D
        % Partial derivatives of the index w.r.t. each parameter, holding Mg
        % fixed (the last column is w.r.t. the log scale). Includes the constant.
        Z = [xt wxt Mg(:, j) ones(mr, 1) unobs(j)*ones(mr, 1)*sigma];
        bbf = xt*bnew(1:k) + wxt*bnew(k+1:2*k) + Mg(:, j)*rho ...
            + bnew(2*k + 2)*ones(mr, 1) + sigma*unobs(j)*ones(mr, 1);

        indp_b1f = indp_b1f + 1./(1 + exp(-2*bbf));

        % tanh'(x) = 4/(exp(2x) + exp(-2x) + 2)
        deri = 4./(exp(2*bbf) + exp(-2*bbf) + 2);
        T = spdiags(deri, 0, mr, mr);
        % Solve (I - rho*T*W) * temp2 = T*Z rather than forming the inverse.
        temp2 = (speye(mr) - rho*T*wt) \ (T*Z);
        gra_gf = gra_gf + Z + rho*wt*temp2;
    end

    % Observed outcome (recoded back to 0/1) minus the simulated P(y = +1).
    resid = full((ones(mr, 1) + yt)/2 - indp_b1f/D);
    graf1 = 2*spdiags(resid, 0, mr, mr)*gra_gf/D;
    graf = graf + graf1'*graf1;
end

msscore = graf/N;
se = sqrt(diag(pinv(msscore)/N));
% Two-sided normal p-values: 2*(1 - normcdf(|z|)) == erfc(|z|/sqrt(2)).
% erfc needs no Statistics Toolbox and avoids cancellation in 1 - normcdf
% when |z| is large.
result = full([bnew se erfc(abs(bnew./se)/sqrt(2))])
% fval(count) was filled by fminunc at every outer iteration; this prints -2
% times the whole history (the last entry belongs to bnew).
-2*fval
    
